ENH: add partition and argpartition functions#449
lucascolley merged 15 commits into data-apis:main
see array-api-extra/tests/test_funcs.py, line 1183 in ca20f03
lucascolley
left a comment
thanks @cakedev0 !
Looks like there is also a merge conflict now.
src/array_api_extra/_delegation.py
Outdated
    kth += 1  # HACK: we use a non-specified behavior of torch.topk:
    # in `a_left`, the element in the last position is the max
    a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
Hmm, I would rather not rely on undocumented behaviour. Is there an alternative?
Fair ^^
Three options:
1. Add an `assert a_left.max() == a_left[k]`.
2. We can just re-run the same logic with `kth=1` and `largest=True`. This is probably 10 to 100% slower depending on the input, but it doesn't add a lot of logic.
3. We can do `if a_left.max() != a_left[k]: swap_max_with_last_element(a_left, axis=-1)`. This requires implementing `swap_max_with_last_element` (and the equivalent for argsort).

I vote for 1 because I'm lazy but I like perf :p
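To make option 3 concrete, here is a minimal sketch in plain Python on a 1-D list. It is only an illustration: `swap_max_with_last_element` is the hypothetical helper named above, it does not exist in torch or in this PR, and a real version would have to work along an axis of an nd-array and keep the indices in sync.

```python
def swap_max_with_last_element(values):
    """Return a copy of `values` with its maximum moved to the last position.

    Hypothetical 1-D helper for illustration only.
    """
    out = list(values)
    if not out:
        return out
    # Position of the (first) maximum element.
    i_max = max(range(len(out)), key=out.__getitem__)
    # Swapping is a no-op when the maximum is already last.
    out[i_max], out[-1] = out[-1], out[i_max]
    return out


# The kth + 1 smallest elements in arbitrary order,
# as an unsorted top-k might return them.
a_left = [3, 1, 2, 0]
fixed = swap_max_with_last_element(a_left)
```

After the swap, the last element of `fixed` is the maximum of `a_left`, which is the invariant the `assert` in option 1 checks.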
Edit: wait I need to rethink something about numpy.partition specs...
So! I rewrote this section entirely; it now relies on torch.kthvalue and is closely aligned with numpy's behavior.
On a side note: the description of the partition function's behavior in numpy is fairly blurry when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).
I may open an issue on numpy to ask for some clarification.
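The difference between the two properties can be sketched with plain Python lists. The helper names `is_partitioned` and `is_three_way_partitioned` are made up for this illustration and are not part of numpy, jax, or this PR; the second one encodes the stricter `<`, `==`, `>` layout described above.

```python
def is_partitioned(values, kth):
    """Weak property: nothing before `kth` is larger than the pivot,
    nothing from `kth` onwards is smaller."""
    pivot = values[kth]
    return (all(v <= pivot for v in values[:kth])
            and all(v >= pivot for v in values[kth:]))


def is_three_way_partitioned(values, kth):
    """Stricter property: all '<' elements first, then all '==', then all '>'."""
    pivot = values[kth]
    # Map every element to -1 (less), 0 (equal) or 1 (greater).
    keys = [(v > pivot) - (v < pivot) for v in values]
    return keys == sorted(keys)


# Duplicates of the pivot (2) are scattered on both sides of kth=3.
scattered = [1, 2, 0, 2, 2, 5, 2, 7]
# Same elements, with the duplicates grouped in the middle.
grouped = [1, 0, 2, 2, 2, 2, 5, 7]
```

Here `scattered` satisfies the weak property only, while `grouped` satisfies both, so a test written against the three-way layout would reject an output like `scattered`.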
> On a side note: the description of the partition function's behavior in numpy is fairly blurry when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).
It might be worth contributing this consideration to the array API spec discussion:
lucascolley
left a comment
thanks @cakedev0, looks close!
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
Thanks for the responsive and helpful reviews @lucascolley. Sorry about the numpy docs style details; I'll make sure to read the docs on this carefully before opening another PR 😉
lucascolley
left a comment
thanks @cakedev0, let's merge it! And thanks for taking a look too @ogrisel .
Would be great to follow-up on #449 (comment).
Are you interested in taking over gh-341 next?
@lucascolley I see that this has been milestoned for 0.9.1, but those are new functions. Wouldn't it make sense to make them part of a new 0.10.0 release instead?
We use https://jacobtomlinson.dev/effver/, so no. Unless you see any problems with that!
Ok, as you wish.
Closes #448
Supports nd-arrays.
The crux of this PR is the torch part: transforming `kthvalue` outputs to a partition/argpartition. The rest is mostly wrappers and checks.
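As a rough picture of that transformation, the sketch below builds an argpartition and a partition from a known k-th value, using plain Python lists. It is not the PR's torch code: both function names are invented for the example, `sorted(values)[kth]` merely stands in for the value a k-th-value routine would return, and the real implementation works on nd-arrays.

```python
def argpartition_from_kth_value(values, kth_value):
    """Indices of `values` grouped as: < kth_value, == kth_value, > kth_value."""
    less = [i for i, v in enumerate(values) if v < kth_value]
    equal = [i for i, v in enumerate(values) if v == kth_value]
    greater = [i for i, v in enumerate(values) if v > kth_value]
    return less + equal + greater


def partition_from_kth_value(values, kth_value):
    """Gather `values` in the order given by the argpartition above."""
    return [values[i] for i in argpartition_from_kth_value(values, kth_value)]


values = [5, 1, 3, 0, 4, 3]
kth = 2
# Stand-in for the k-th smallest value (0-based kth), assumed to be given.
kth_value = sorted(values)[kth]

indices = argpartition_from_kth_value(values, kth_value)
partitioned = partition_from_kth_value(values, kth_value)
```

The element at position `kth` of `partitioned` equals `kth_value`, while the groups on either side are left unsorted, which is what distinguishes a partition from a full sort.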